{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Diatom analysis\n",
    "\n",
    "See https://www.nature.com/articles/s41524-019-0202-3:\n",
    "\n",
    "**Deep data analytics for genetic engineering of diatoms linking genotype to phenotype via machine learning**, Artem A. Trofimov, Alison A. Pawlicki, Nikolay Borodinov, Shovon Mandal, Teresa J. Mathews, Mark Hildebrand, Maxim A. Ziatdinov, Katherine A. Hausladen, Paulina K. Urbanowicz, Chad A. Steed, Anton V. Ievlev, Alex Belianinov, Joshua K. Michener, Rama Vasudevan, and Olga S. Ovchinnikova."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "\n",
    "# Notebook-wide matplotlib defaults, set one key at a time.\n",
    "matplotlib.rcParams['figure.figsize'] = (10, 10)  # larger images\n",
    "matplotlib.rcParams['image.cmap'] = 'gray'  # gray color map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage import io\n",
    "\n",
    "# Read the micrograph; the path is relative to this notebook.\n",
    "image = io.imread('../data/diatom-wild-032.jpg')\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(image);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the first 690 rows of the image (all columns).\n",
    "# NOTE(review): presumably this crops away a non-pore strip at the\n",
    "# bottom of the micrograph - confirm against the displayed image.\n",
    "pores = image[:690, :]\n",
    "\n",
    "plt.imshow(pores);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import ndimage as ndi\n",
    "from skimage import util\n",
    "\n",
    "# Convert to floating point first, then apply a size-3 median filter.\n",
    "pores_float = util.img_as_float(pores)\n",
    "denoised = ndi.median_filter(pores_float, size=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the median-filtered image.\n",
    "plt.imshow(denoised);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage import exposure\n",
    "\n",
    "# Gamma correction of the denoised image with gamma = 0.7.\n",
    "pores_gamma = exposure.adjust_gamma(denoised, gamma=0.7)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(pores_gamma);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Invert the intensities (1 - value) and display the result.\n",
    "# NOTE(review): `pores_inv` is only displayed here; the thresholding\n",
    "# cell below works on `pores_gamma`.\n",
    "pores_inv = 1 - pores_gamma\n",
    "plt.imshow(pores_inv);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is the problematic part of the manual pipeline: you need\n",
    "# a good segmentation.  There are algorithms for automatic thresholding,\n",
    "# such as `filters.threshold_otsu` and `filters.threshold_li`, but they\n",
    "# don't always get the result you want.\n",
    "\n",
    "# Hand-picked threshold: pixels at or below `t` are marked True.\n",
    "t = 0.325\n",
    "thresholded = (pores_gamma <= t)\n",
    "\n",
    "plt.imshow(thresholded);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage import filters\n",
    "\n",
    "# Plot the outcome of several automatic thresholding methods on the\n",
    "# gamma-corrected image, for comparison with the manual threshold.\n",
    "filters.try_all_threshold(pores_gamma, figsize=(15, 20));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage import segmentation, morphology, color"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Euclidean distance from each True pixel of the mask to the nearest\n",
    "# False (background) pixel.\n",
    "distance = ndi.distance_transform_edt(thresholded)\n",
    "\n",
    "# The gamma adjustment is for display only; `distance` itself is unchanged.\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(exposure.adjust_gamma(distance, 0.5))\n",
    "ax.set_title('Distance to background map');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mask of the local maxima of the distance map; these are turned into\n",
    "# watershed markers further down.\n",
    "local_maxima = morphology.local_maxima(distance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(20, 20))\n",
    "\n",
    "# Row and column indices of every local-maximum pixel.\n",
    "maxi_coords = np.nonzero(local_maxima)\n",
    "maxi_rows, maxi_cols = maxi_coords\n",
    "\n",
    "# Matplotlib takes x (columns) first, then y (rows).\n",
    "ax.imshow(pores)\n",
    "ax.scatter(maxi_cols, maxi_rows);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is a utility function that we'll use for display in a while;\n",
    "# you can ignore it for now and come and investigate later.\n",
    "\n",
    "def shuffle_labels(labels):\n",
    "    \"\"\"Shuffle the labels so that they are no longer in order.\n",
    "\n",
    "    This helps with visualization.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    labels : ndarray of non-negative int\n",
    "        Label image. 0 is treated as background and always stays 0.\n",
    "        The non-zero values do not have to be consecutive.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    shuffled : ndarray\n",
    "        Array with the same shape as `labels` in which the non-zero label\n",
    "        values have been randomly permuted among the regions (using the\n",
    "        global `np.random` state). Pixels that shared a label still share\n",
    "        one afterwards.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    if labels.size == 0:\n",
    "        # Nothing to shuffle, and `max()` is undefined on empty input.\n",
    "        return labels.copy()\n",
    "\n",
    "    present = np.unique(labels[labels != 0])\n",
    "\n",
    "    # Lookup table indexed by label *value*. Sizing it by the largest label\n",
    "    # (rather than by the number of distinct labels) keeps the lookup valid\n",
    "    # when label values are not consecutive. Entry 0 stays 0.\n",
    "    lookup = np.zeros(int(labels.max()) + 1, dtype=labels.dtype)\n",
    "    lookup[present] = np.random.permutation(present)\n",
    "    return lookup[labels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give each connected group of local-maximum pixels its own integer id,\n",
    "# then flood the denoised image starting from those markers.\n",
    "markers, n_markers = ndi.label(local_maxima)\n",
    "labels = segmentation.watershed(denoised, markers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
    "ax0, ax1, ax2 = axes\n",
    "\n",
    "# Left to right: threshold mask, log-scaled distance map, and the\n",
    "# (shuffled) labels of the unmasked watershed.\n",
    "ax0.imshow(thresholded)\n",
    "ax1.imshow(np.log(1 + distance))\n",
    "ax2.imshow(shuffle_labels(labels), cmap='magma');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second watershed from the same markers, restricted to the thresholded\n",
    "# pixels via `mask`. Note that the boolean mask is also passed as the\n",
    "# image to flood.\n",
    "labels_masked = segmentation.watershed(thresholded, markers, mask=thresholded, connectivity=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
    "ax0, ax1, ax2 = axes\n",
    "\n",
    "# Left to right: threshold mask, log-scaled distance map, and the\n",
    "# (shuffled) labels of the masked watershed.\n",
    "ax0.imshow(thresholded)\n",
    "ax1.imshow(np.log(1 + distance))\n",
    "ax2.imshow(shuffle_labels(labels_masked), cmap='magma');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage import measure\n",
    "\n",
    "# Iso-contours of the masked watershed labels at level 0.5.\n",
    "contours = measure.find_contours(labels_masked, level=0.5)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(pores)\n",
    "for c in contours:\n",
    "    # Each contour has one (row, col) pair per vertex; plot as x=col, y=row.\n",
    "    contour_rows, contour_cols = c.T\n",
    "    ax.plot(contour_cols, contour_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-region measurements for the masked watershed labels; `area` and\n",
    "# `eccentricity` are read from these further down.\n",
    "regions = measure.regionprops(labels_masked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of region areas, restricted to the 0-200 range.\n",
    "region_areas = [r.area for r in regions]\n",
    "\n",
    "f, ax = plt.subplots(figsize=(10, 3))\n",
    "ax.hist(region_areas, bins=100, range=(0, 200));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import models, layers\n",
    "from keras.layers import Conv2D, MaxPooling2D, UpSampling2D\n",
    "\n",
    "M = 76\n",
    "# Side length of the square network input (46 for M = 76).\n",
    "N = int(23 / 76 * M) * 2\n",
    "\n",
    "# Convolve, downsample by 2, convolve twice, upsample by 2, then map to a\n",
    "# single sigmoid channel. All convolutions use 'same' padding.\n",
    "model = models.Sequential([\n",
    "    Conv2D(\n",
    "        32,\n",
    "        kernel_size=(2, 2),\n",
    "        activation='relu',\n",
    "        input_shape=(N, N, 1),\n",
    "        padding='same'\n",
    "    ),\n",
    "    MaxPooling2D(pool_size=(2, 2)),\n",
    "    Conv2D(64, (3, 3), activation='relu', padding='same'),\n",
    "    Conv2D(64, (3, 3), activation='relu', padding='same'),\n",
    "    UpSampling2D(size=(2, 2)),\n",
    "    Conv2D(\n",
    "        1,\n",
    "        kernel_size=(2, 2),\n",
    "        activation='sigmoid',\n",
    "        padding='same'\n",
    "    ),\n",
    "])\n",
    "model.compile(loss='mse', optimizer='Adam', metrics=['accuracy'])\n",
    "\n",
    "# Load pre-trained weights from disk\n",
    "model.load_weights('../data/keras_model-diatoms-pores.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round the image shape up to the next multiple of the 46-pixel tile size.\n",
    "shape = np.array(pores.shape)\n",
    "padded_shape = (np.ceil(shape / 46) * 46).astype(int)\n",
    "delta_shape = padded_shape - shape\n",
    "pad_rows, pad_cols = delta_shape\n",
    "\n",
    "# Pad only after the last row and column, mirroring the image content.\n",
    "padded_pores = np.pad(\n",
    "    pores,\n",
    "    pad_width=[(0, pad_rows), (0, pad_cols)],\n",
    "    mode='symmetric'\n",
    ")\n",
    "\n",
    "# Non-overlapping 46 x 46 tiles.\n",
    "blocks = util.view_as_blocks(padded_pores, (46, 46))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of tile rows / columns; needed to reassemble the tiles later.\n",
    "B_rows, B_cols = blocks.shape[:2]\n",
    "\n",
    "tiles = blocks.reshape([-1, 46, 46])\n",
    "\n",
    "# `predict` wants input of shape (n_tiles, 46, 46, 1).\n",
    "# `Sequential.predict_classes` is no longer available in current Keras;\n",
    "# for a single sigmoid output it was equivalent to thresholding the\n",
    "# predicted probabilities at 0.5, which is done explicitly here.\n",
    "tile_probabilities = model.predict(tiles[..., np.newaxis])\n",
    "tile_masks = tile_probabilities > 0.5\n",
    "\n",
    "print(tile_masks.shape)\n",
    "# Drop the trailing channel axis.\n",
    "tile_masks = tile_masks[..., 0]\n",
    "print(tile_masks.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitch the per-tile masks back together on the original tile grid ...\n",
    "padded_mask = util.montage(tile_masks, grid_shape=(B_rows, B_cols))\n",
    "\n",
    "# ... and cut off the rows / columns that were added as padding.\n",
    "nn_mask = padded_mask[:shape[0], :shape[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the mask assembled from the network predictions.\n",
    "plt.imshow(nn_mask);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outline the network mask on top of the cropped image.\n",
    "contours = measure.find_contours(nn_mask, level=0.5)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(pores)\n",
    "for c in contours:\n",
    "    # Each contour has one (row, col) pair per vertex; plot as x=col, y=row.\n",
    "    contour_rows, contour_cols = c.T\n",
    "    ax.plot(contour_cols, contour_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label the connected components of the network mask, then measure them.\n",
    "nn_labels = measure.label(nn_mask)\n",
    "nn_regions = measure.regionprops(nn_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = plt.subplots(figsize=(10, 3))\n",
    "\n",
    "# Overlay the region-area distributions of both segmentations.\n",
    "for name, region_list in [('Classic', regions), ('NN', nn_regions)]:\n",
    "    ax.hist([r.area for r in region_list], bins='auto', range=(0, 200), alpha=0.4, label=name)\n",
    "\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus round: region filtering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_circular(regions, eccentricity_threshold=0.1, area_threshold=10):\n",
    "    \"\"\"Calculate a boolean mask indicating which regions are circular.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    regions : iterable of region properties\n",
    "        Objects exposing `area` and `eccentricity` attributes, such as\n",
    "        the items returned by `measure.regionprops`.\n",
    "    eccentricity_threshold : float, >= 0\n",
    "        Regions with an eccentricity less than or equal to this value\n",
    "        are considered circular. See `measure.regionprops`.\n",
    "    area_threshold : int\n",
    "        Only regions with an area greater than this value are considered\n",
    "        circular.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    circular : ndarray of bool, shape (n_regions,)\n",
    "        True where a region passes both tests. The dtype is bool even\n",
    "        when `regions` is empty, so the result can always be negated\n",
    "        with `~` and used as an index mask.\n",
    "    \"\"\"\n",
    "    return np.array(\n",
    "        [\n",
    "            (r.area > area_threshold) and\n",
    "            (r.eccentricity <= eccentricity_threshold)\n",
    "            for r in regions\n",
    "        ],\n",
    "        # Without an explicit dtype an empty list becomes a float64 array.\n",
    "        dtype=bool\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filtered_mask(mask, regions, eccentricity_threshold, area_threshold):\n",
    "    \"\"\"Return a copy of `mask` with the non-circular regions set to 0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mask : ndarray\n",
    "        Mask from which `regions` were measured. It is not modified.\n",
    "    regions : sequence of region properties\n",
    "        Regions of `mask`, e.g. as returned by `measure.regionprops`.\n",
    "        Each must expose `area`, `eccentricity` and `coords`.\n",
    "    eccentricity_threshold : float\n",
    "        Passed on to `is_circular`.\n",
    "    area_threshold : int\n",
    "        Passed on to `is_circular`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    filtered : ndarray\n",
    "        Copy of `mask` in which the pixels of every region rejected by\n",
    "        `is_circular` are 0. An empty `regions` returns an unchanged copy.\n",
    "    \"\"\"\n",
    "    filtered = mask.copy()\n",
    "    keep = is_circular(\n",
    "        regions,\n",
    "        eccentricity_threshold=eccentricity_threshold,\n",
    "        area_threshold=area_threshold\n",
    "    )\n",
    "\n",
    "    # Walk regions and flags in lockstep instead of fancy-indexing\n",
    "    # `np.array(regions)`: this works for an empty region list and does not\n",
    "    # depend on NumPy turning region objects into an object array.\n",
    "    for region, is_kept in zip(regions, keep):\n",
    "        if not is_kept:\n",
    "            # `coords` holds one row per pixel; transposing gives one index\n",
    "            # array per image axis.\n",
    "            filtered[tuple(region.coords.T)] = 0\n",
    "\n",
    "    return filtered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the network regions with area > 20 and eccentricity <= 0.8.\n",
    "nn_mask_filtered = filtered_mask(\n",
    "    nn_mask,\n",
    "    nn_regions,\n",
    "    eccentricity_threshold=0.8,\n",
    "    area_threshold=20\n",
    ")\n",
    "plt.imshow(nn_mask_filtered);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outline only the regions that survive the circularity filter.\n",
    "nn_mask_circular = filtered_mask(\n",
    "    nn_mask,\n",
    "    nn_regions,\n",
    "    eccentricity_threshold=0.8,\n",
    "    area_threshold=20\n",
    ")\n",
    "contours = measure.find_contours(nn_mask_circular, level=0.5)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(pores)\n",
    "for c in contours:\n",
    "    # Each contour has one (row, col) pair per vertex; plot as x=col, y=row.\n",
    "    contour_rows, contour_cols = c.T\n",
    "    ax.plot(contour_cols, contour_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the network regions that pass the circularity filter\n",
    "# (eccentricity <= 0.8, area > 20) and plot their area distribution.\n",
    "circular = is_circular(nn_regions, 0.8, 20)\n",
    "filtered_regions = np.array(nn_regions)[circular]\n",
    "filtered_areas = [r.area for r in filtered_regions]\n",
    "\n",
    "f, ax = plt.subplots(figsize=(10, 3))\n",
    "ax.hist(filtered_areas, bins='auto', range=(0, 200), alpha=0.4);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
